{
  "cells": [
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# code by Tae Hwan Jung @graykode\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "\n",
        "class TextCNN(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(TextCNN, self).__init__()\n",
        "        self.num_filters_total = num_filters * len(filter_sizes)\n",
        "        self.W = nn.Embedding(vocab_size, embedding_size)\n",
        "        self.Weight = nn.Linear(self.num_filters_total, num_classes, bias=False)\n",
        "        self.Bias = nn.Parameter(torch.ones([num_classes]))\n",
        "        self.filter_list = nn.ModuleList([nn.Conv2d(1, num_filters, (size, embedding_size)) for size in filter_sizes])\n",
        "\n",
        "    def forward(self, X):\n",
        "        embedded_chars = self.W(X) # [batch_size, sequence_length, sequence_length]\n",
        "        embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size]\n",
        "\n",
        "        pooled_outputs = []\n",
        "        for i, conv in enumerate(self.filter_list):\n",
        "            # conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]\n",
        "            h = F.relu(conv(embedded_chars))\n",
        "            # mp : ((filter_height, filter_width))\n",
        "            mp = nn.MaxPool2d((sequence_length - filter_sizes[i] + 1, 1))\n",
        "            # pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]\n",
        "            pooled = mp(h).permute(0, 3, 2, 1)\n",
        "            pooled_outputs.append(pooled)\n",
        "\n",
        "        h_pool = torch.cat(pooled_outputs, len(filter_sizes)) # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]\n",
        "        h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total]) # [batch_size(=6), output_height * output_width * (output_channel * 3)]\n",
        "        model = self.Weight(h_pool_flat) + self.Bias # [batch_size, num_classes]\n",
        "        return model\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    embedding_size = 2 # embedding size\n",
        "    sequence_length = 3 # sequence length\n",
        "    num_classes = 2 # number of classes\n",
        "    filter_sizes = [2, 2, 2] # n-gram windows\n",
        "    num_filters = 3 # number of filters\n",
        "\n",
        "    # 3 words sentences (=sequence_length is 3)\n",
        "    sentences = [\"i love you\", \"he loves me\", \"she likes baseball\", \"i hate you\", \"sorry for that\", \"this is awful\"]\n",
        "    labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.\n",
        "\n",
        "    word_list = \" \".join(sentences).split()\n",
        "    word_list = list(set(word_list))\n",
        "    word_dict = {w: i for i, w in enumerate(word_list)}\n",
        "    vocab_size = len(word_dict)\n",
        "\n",
        "    model = TextCNN()\n",
        "\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
        "\n",
        "    inputs = torch.LongTensor([np.asarray([word_dict[n] for n in sen.split()]) for sen in sentences])\n",
        "    targets = torch.LongTensor([out for out in labels]) # To using Torch Softmax Loss function\n",
        "\n",
        "    # Training\n",
        "    for epoch in range(5000):\n",
        "        optimizer.zero_grad()\n",
        "        output = model(inputs)\n",
        "\n",
        "        # output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)\n",
        "        loss = criterion(output, targets)\n",
        "        if (epoch + 1) % 1000 == 0:\n",
        "            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n",
        "\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "    # Test\n",
        "    test_text = 'sorry hate you'\n",
        "    tests = [np.asarray([word_dict[n] for n in test_text.split()])]\n",
        "    test_batch = torch.LongTensor(tests)\n",
        "\n",
        "    # Predict\n",
        "    predict = model(test_batch).data.max(1, keepdim=True)[1]\n",
        "    if predict[0][0] == 0:\n",
        "        print(test_text,\"is Bad Mean...\")\n",
        "    else:\n",
        "        print(test_text,\"is Good Mean!!\")"
      ],
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "anaconda-cloud": {},
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.1"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}